-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TensorRT EP] Fix InferenceSession::Run() not thread-safe issue #19301
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
jywu-msft
reviewed
Jan 29, 2024
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Outdated
Show resolved
Hide resolved
jywu-msft
previously approved these changes
Jan 29, 2024
jywu-msft
approved these changes
Jan 30, 2024
YUNQIUGUO
pushed a commit
that referenced
this pull request
Jan 30, 2024
Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: - It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple streams which is not suggested. But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. Therefore, TRT EP needs to call cudaStreamSynchronize() at compute_func() which means to wait until stream has completed all operations to prevent the concurrent github isse: #19275
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently,
TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen:
In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT.
So TRT EP will end up having one trt execution context using multiple streams which is not suggested.
But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed.
Therefore, TRT EP needs to call cudaStreamSynchronize() at compute_func() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above
github isse: #19275